
#include <winsock2.h>
#include <mstcpip.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <process.h>
#include <windows.h>

#include "zmisc_socket.h"


/* Backlog argument passed to listen() by start_accept_thread(). */
#define LISTEN_MAX_COUNT 10240

/* Fallback for SDK headers that do not define SO_UPDATE_ACCEPT_CONTEXT.
 * update_accept_socket() sets this option on every accepted socket, using the
 * listening socket as the option value. */
#ifndef SO_UPDATE_ACCEPT_CONTEXT
#define SO_UPDATE_ACCEPT_CONTEXT    0x700B
#endif //SO_UPDATE_ACCEPT_CONTEXT

/* Timeout, in milliseconds, of each GetQueuedCompletionStatus() wait in the
 * accept thread. It bounds how long the thread can go without re-checking
 * lv_exit_flag. Can be overridden at build time. */
#ifndef ACCEPT_TIME_OUT
#define ACCEPT_TIME_OUT   1000
#endif //ACCEPT_TIME_OUT


/* Number of AcceptEx operations kept posted: one per accept_over_lapped slot. */
#define PRE_ACCEPT_COUNT  128
/* Padding that follows each sockaddr_in in accept_over_lapped. Together they
 * form a region of sizeof(struct sockaddr_in) + 16 bytes, which matches the
 * address lengths pre_accpet_client() passes to AcceptEx. This holds as long as
 * sizeof(struct sockaddr) == sizeof(struct sockaddr_in) (both 16 on Winsock). */
#define APPEND_ADDR_BUFFER (sizeof(struct sockaddr) + 16 - sizeof(struct sockaddr_in))

/* Declared but not referenced in this part of the file. Presumably implemented
 * by the I/O worker module -- verify. */
extern int iocp_attach( void* p_worker, SOCKET s_socket );

/*
 * One pre-posted accept slot.
 *
 * st_ovlp must stay the first member. accept_work_thread() turns the OVERLAPPED
 * pointer returned by GetQueuedCompletionStatus() straight back into a
 * struct accept_over_lapped* by casting.
 *
 * The four address members make up the AcceptEx output buffer.
 * pre_accpet_client() passes &st_local_addr as the buffer, with a local and a
 * remote address length of sizeof(struct sockaddr_in) + 16 each. AcceptEx
 * therefore writes across all four members, so they must stay contiguous and in
 * this order.
 */
struct accept_over_lapped
{
    OVERLAPPED         st_ovlp;         /* completion record of this slot's AcceptEx */
    SOCKET             s_socket;        /* socket accepted into; reset to INVALID_SOCKET once the slot no longer owns it */
    struct sockaddr_in st_local_addr;   /* start of the AcceptEx output buffer; overwritten by getsockname() in update_accept_socket() */
    char               ch_append_local[ APPEND_ADDR_BUFFER ];
    struct sockaddr_in st_remote_addr;  /* overwritten by getpeername() in update_accept_socket() */
    char               ch_append_remote[ APPEND_ADDR_BUFFER ];
};


/*
 * Pointer type for the Winsock AcceptEx extension function. The actual
 * pointer is looked up at runtime by zmisc_socket_get_accept_ex_func() and
 * stored in struct iocp_accept::pfn_accept_ex.
 */
typedef BOOL (__stdcall *pfn_acceptex)( SOCKET       listen_socket,
                                        SOCKET       accept_socket,
                                        PVOID        output_buffer,
                                        DWORD        receive_data_length,
                                        DWORD        local_address_length,
                                        DWORD        remote_address_length,
                                        LPDWORD      bytes_received,
                                        LPOVERLAPPED overlapped );


/*
 * State for one listening socket that is serviced by a dedicated accept thread.
 * Allocated by start_accept_thread() and released by stop_accept_thread().
 */
struct iocp_accept
{
    HANDLE                    h_listen_iocp;     /* completion port that s_listen is associated with */
    SOCKET                    s_listen;          /* overlapped listening socket */
    HANDLE                    th_accept_thread;  /* handle returned by _beginthreadex() for accept_work_thread() */
    long volatile             lv_exit_flag;      /* 0 while running; stop_accept_thread() sets it to -1.
                                                    NOTE(review): volatile gives no ordering guarantee in
                                                    standard C; Interlocked* would make the intent explicit. */
    pfn_acceptex              pfn_accept_ex;     /* AcceptEx pointer from zmisc_socket_get_accept_ex_func() */
    void*                     p_user_data;       /* opaque value handed back to pfn_on_accpet */

    /* Called on the accept thread for every accepted socket. A return of 0 means
     * the callback keeps the socket; any other value makes the accept thread
     * close it. */
    int (*pfn_on_accpet)( void* p_user_data, SOCKET s );
    struct accept_over_lapped st_accept_ovlp[PRE_ACCEPT_COUNT];  /* slots kept armed with AcceptEx */
};


static int pre_accpet_client( struct iocp_accept* pst_iocp_accept, struct accept_over_lapped* pst_accept_ovlp )
{
    int   i_result   = -1;
    BOOL  b_accpeted = FALSE;
    DWORD dw_len     = 0;

    for ( ; ; )
    {
        memset( pst_accept_ovlp, 0, sizeof(struct accept_over_lapped) );
        pst_accept_ovlp->s_socket = WSASocket( AF_INET, SOCK_STREAM, IPPROTO_TCP, (LPWSAPROTOCOL_INFO)0, 0, WSA_FLAG_OVERLAPPED );

        if ( INVALID_SOCKET == pst_accept_ovlp->s_socket )
        {
            printf("get accept socket error\n");
            break;
        }

        //printf( "post socket:%d\n", pst_accept_ovlp->s_socket );

        b_accpeted = pst_iocp_accept->pfn_accept_ex( pst_iocp_accept->s_listen
                                                    , pst_accept_ovlp->s_socket
                                                    , &pst_accept_ovlp->st_local_addr
                                                    , 0
                                                    , sizeof(pst_accept_ovlp->st_local_addr) + 16
                                                    , sizeof(pst_accept_ovlp->st_remote_addr) + 16
                                                    , &dw_len
                                                    , &pst_accept_ovlp->st_ovlp
                                                     );

        if ( FALSE != b_accpeted )
        {
            i_result = 0;
            break;
        }

        if ( ERROR_IO_PENDING == WSAGetLastError() )
        {
            i_result = 0;
            break;
        }

        closesocket( pst_accept_ovlp->s_socket );
        pst_accept_ovlp->s_socket = INVALID_SOCKET;

        //printf("post accept socket error:%u\n", pst_accept_ovlp->s_socket);

        break;
    }

    return i_result;
}


static int update_accept_socket( struct iocp_accept* pst_iocp_accept, struct accept_over_lapped* pst_accept_ovlp )
{
    int                  i_result       = -1;
    int                  i_addr_len     = 0;
    int                  i_option       = 0;
    SOCKET               s_socket_mould = pst_iocp_accept->s_listen;
    SOCKET               s_socket_new   = pst_accept_ovlp->s_socket;
    BOOL                 bKeepAlive     = TRUE;
    struct tcp_keepalive alive_in       = {0};
    struct tcp_keepalive alive_out      = {0};
    DWORD                dw_ret         = 0;

    for ( ; ; )
    {
        if (SOCKET_ERROR == setsockopt( s_socket_new, SOL_SOCKET, SO_UPDATE_ACCEPT_CONTEXT, (char *)&s_socket_mould, sizeof(s_socket_mould) ) )
        {
            break;
        }

        i_addr_len = sizeof( pst_accept_ovlp->st_local_addr );
        if (SOCKET_ERROR == getsockname( s_socket_new, (struct sockaddr*)&pst_accept_ovlp->st_local_addr, &i_addr_len ) )
        {
            break;
        }

        i_addr_len = sizeof( pst_accept_ovlp->st_remote_addr );
        if (SOCKET_ERROR == getpeername( s_socket_new, (struct sockaddr*)&pst_accept_ovlp->st_remote_addr, &i_addr_len ) )
        {
            break;
        }

        i_option = 10000;
        if ( SOCKET_ERROR == setsockopt( s_socket_new, SOL_SOCKET, SO_RCVTIMEO,(const char*)&i_option, sizeof(i_option)) )
        {
            break;
        }

        i_option = 10000;
        if ( SOCKET_ERROR == setsockopt( s_socket_new, SOL_SOCKET, SO_SNDTIMEO,(const char*)&i_option, sizeof(i_option)) )
        {
            break;
        }

        i_option = 1;
        if ( SOCKET_ERROR == setsockopt( s_socket_new, IPPROTO_TCP, TCP_NODELAY, (const char*)&i_option, sizeof(i_option)) )
        {
            break;
        }


        if ( SOCKET_ERROR == setsockopt( s_socket_new, SOL_SOCKET, SO_KEEPALIVE, (char*)&bKeepAlive, sizeof(bKeepAlive)) )
        {
            break;
        }

        alive_in.keepalivetime     = 5000;
        alive_in.keepaliveinterval = 1000;
        alive_in.onoff             = TRUE;
        if ( SOCKET_ERROR == WSAIoctl( s_socket_new, SIO_KEEPALIVE_VALS, &alive_in, sizeof(alive_in), &alive_out, sizeof(alive_out), &dw_ret, (LPWSAOVERLAPPED)0, (LPWSAOVERLAPPED_COMPLETION_ROUTINE)0) )
        {
            break;
        }

        i_result = 0;

        break;
    }

    return i_result;
}


static unsigned __stdcall accept_work_thread( void* p_void )
{
    struct iocp_accept*        pst_iocp_accept     = (struct iocp_accept*)p_void;
    DWORD                      dw_number_of_bytes  = 0;
    ULONG_PTR                  ulp_completion_key  = 0;
    struct accept_over_lapped* pst_accept_ovlp     = (struct accept_over_lapped*)0;
    int                        i_accept_result     = -1;
    void*                      p_user_data         = (void*)0;
    int                        i_max_count_idle    = 0;

    if ( (struct iocp_accept*)0 == pst_iocp_accept )
    {
        _endthreadex( (unsigned)-1 );
        return (unsigned)-1;
    }

    p_user_data = pst_iocp_accept->p_user_data;

    for ( ; 0 == pst_iocp_accept->lv_exit_flag; )
    {
        i_accept_result = -1;

        if ( FALSE == GetQueuedCompletionStatus( pst_iocp_accept->h_listen_iocp, &dw_number_of_bytes, &ulp_completion_key, (OVERLAPPED **)&pst_accept_ovlp, ACCEPT_TIME_OUT ) )
        {
            //if ( (struct accept_over_lapped*)0 != pst_accept_ovlp )
            //{
            //    printf( "GetQueuedCompletionStatus error\n" );
            //}
            continue;
        }

        if ( ((struct accept_over_lapped*)0 == pst_accept_ovlp) && ( 0 == ulp_completion_key ) && (0 == dw_number_of_bytes) )
        {
            //printf( "recv exit\n" );
            //fflush(stdout);
            continue;
        }

        if ( 0 != update_accept_socket(pst_iocp_accept, pst_accept_ovlp) )
        {
            printf( "update_accept_socket\n" );
        }

        if ( 0 == pst_iocp_accept->pfn_on_accpet( p_user_data, pst_accept_ovlp->s_socket ) )
        {
            pst_accept_ovlp->s_socket = INVALID_SOCKET;
        }

        if (  INVALID_SOCKET != pst_accept_ovlp->s_socket  )
        {
            closesocket( pst_accept_ovlp->s_socket );
        }
        pst_accept_ovlp->s_socket = INVALID_SOCKET;

        for( ; 0 == pst_iocp_accept->lv_exit_flag; )
        {
            if ( 0 == pre_accpet_client( pst_iocp_accept, pst_accept_ovlp ) )
            {
                break;
            }
            printf( "pre_accpet_client error\n" );
            Sleep(500);
        }
    }
    _endthreadex( 0 );
    return 0;

}


int stop_accept_thread( void * p_accept_handle )
{
    int                i_result     = 0;
    unsigned long      ul_index     = 0;
    struct iocp_accept *pst_accept_handle = (struct iocp_accept *)p_accept_handle;

    for ( ; ; )
    {
        if ( (struct iocp_accept*)0 == pst_accept_handle )
        {
            break;
        }

        pst_accept_handle->lv_exit_flag = -1;

        PostQueuedCompletionStatus( pst_accept_handle->h_listen_iocp, 0, 0, (LPOVERLAPPED)0 );

        if ( 0 < pst_accept_handle->th_accept_thread )
        {
            WaitForSingleObject( pst_accept_handle->th_accept_thread, INFINITE);
            CloseHandle( pst_accept_handle->th_accept_thread );
        }

        if ( INVALID_SOCKET != pst_accept_handle->s_listen )
        {
            closesocket( pst_accept_handle->s_listen );
            pst_accept_handle->s_listen = INVALID_SOCKET;
        }

        if ( (HANDLE)0 != pst_accept_handle->h_listen_iocp )
        {
            CloseHandle( pst_accept_handle->h_listen_iocp );
            pst_accept_handle->h_listen_iocp = (HANDLE)0;
        }

        for ( ul_index = 0; ul_index < PRE_ACCEPT_COUNT; ++ul_index )
        {
            if ( INVALID_SOCKET == pst_accept_handle->st_accept_ovlp[ul_index].s_socket )
            {
                continue;
            }
            closesocket( pst_accept_handle->st_accept_ovlp[ul_index].s_socket );
            pst_accept_handle->st_accept_ovlp[ul_index].s_socket = INVALID_SOCKET;
        }

        free( pst_accept_handle );

        i_result = 0;

        break;
    }

    return i_result;
}


/*
 * Creates an overlapped listening TCP socket bound to ul_ipv4:us_listen_port,
 * pre-posts PRE_ACCEPT_COUNT AcceptEx operations and starts accept_work_thread().
 *
 * p_user_data    : passed unchanged to pfn_on_accpet.
 * ul_ipv4        : copied into sin_addr.s_addr unchanged, so it must already be
 *                  in network byte order (e.g. INADDR_ANY).
 * us_listen_port : port in host byte order (converted with htons()).
 * pfn_on_accpet  : called on the accept thread for every accepted socket. Return
 *                  0 to keep the socket, non-zero to have it closed. Must not be
 *                  NULL.
 *
 * Returns an opaque handle for stop_accept_thread(), or NULL on failure. On
 * failure, everything created so far has already been released.
 *
 * NOTE(review): Winsock initialisation (WSAStartup) is not done here; presumably
 * the caller does it.
 */
void * start_accept_thread( void* p_user_data, unsigned long ul_ipv4, unsigned short us_listen_port, int (*pfn_on_accpet)( void* p_user_data, SOCKET s ) )
{
    int                   i_result         = -1;
    struct sockaddr_in    st_listen_addr   = {0};
    struct iocp_accept    *pst_iocp_accept = (struct iocp_accept*)0;
    unsigned long         ul_index         = 0;
    unsigned int          ui_thread_id     = 0;

    for ( ; ; )
    {
        /* The accept thread calls this unconditionally for every connection. */
        if ( 0 == pfn_on_accpet )
        {
            break;
        }

        pst_iocp_accept = (struct iocp_accept*)calloc( 1, sizeof(*pst_iocp_accept) );
        if ( (struct iocp_accept*)0 == pst_iocp_accept )
        {
            break;
        }

        /* Zero-filled sockets are 0, not INVALID_SOCKET. Mark them explicitly so
         * that an early failure below makes stop_accept_thread() skip them
         * instead of calling closesocket(0). */
        pst_iocp_accept->s_listen = INVALID_SOCKET;
        for ( ul_index = 0; ul_index < PRE_ACCEPT_COUNT; ++ul_index )
        {
            pst_iocp_accept->st_accept_ovlp[ul_index].s_socket = INVALID_SOCKET;
        }

        pst_iocp_accept->p_user_data               = p_user_data;
        pst_iocp_accept->pfn_on_accpet             = pfn_on_accpet;

        pst_iocp_accept->h_listen_iocp   = CreateIoCompletionPort(INVALID_HANDLE_VALUE, (HANDLE)0, 0, 0);
        if ( (HANDLE)0 == pst_iocp_accept->h_listen_iocp )
        {
            break;
        }

        pst_iocp_accept->s_listen = WSASocket( AF_INET, SOCK_STREAM, IPPROTO_TCP, (LPWSAPROTOCOL_INFO)0, 0, WSA_FLAG_OVERLAPPED );
        if ( INVALID_SOCKET == pst_iocp_accept->s_listen )
        {
            break;
        }

        /* AcceptEx completions on the listening socket arrive on h_listen_iocp,
         * keyed by the listening socket. */
        if ( (HANDLE)0 == CreateIoCompletionPort( (HANDLE)pst_iocp_accept->s_listen, pst_iocp_accept->h_listen_iocp, pst_iocp_accept->s_listen, 0) )
        {
            break;
        }

        if ( SOCKET_ERROR == zmisc_socket_set_reuse_addr( pst_iocp_accept->s_listen ) )
        {
            break;
        }

        if ( SOCKET_ERROR == zmisc_socket_get_accept_ex_func(pst_iocp_accept->s_listen, (void**)&pst_iocp_accept->pfn_accept_ex) )
        {
            break;
        }

        memset(&st_listen_addr, 0, sizeof(st_listen_addr));
        st_listen_addr.sin_family        = AF_INET;
        st_listen_addr.sin_port          = htons(us_listen_port);
        st_listen_addr.sin_addr.s_addr   = ul_ipv4;
        if ( SOCKET_ERROR == bind( pst_iocp_accept->s_listen, (struct sockaddr*)&st_listen_addr, sizeof(st_listen_addr)) )
        {
            break;
        }

        if ( SOCKET_ERROR == listen(pst_iocp_accept->s_listen, LISTEN_MAX_COUNT) )
        {
            break;
        }

        /* A slot that fails here is logged and left unarmed. */
        for ( ul_index = 0; ul_index < PRE_ACCEPT_COUNT; ++ul_index )
        {
            if ( pre_accpet_client( pst_iocp_accept, pst_iocp_accept->st_accept_ovlp + ul_index) )
            {
                printf("accept socket error\n");
            }
        }

        pst_iocp_accept->th_accept_thread = (HANDLE)_beginthreadex( (void *)0, 0, accept_work_thread, pst_iocp_accept, 0, &ui_thread_id );
        /* _beginthreadex() reports failure with 0. -1 is _beginthread()'s error
         * value, so the former check never detected a failure. */
        if ( (HANDLE)0 == pst_iocp_accept->th_accept_thread )
        {
            break;
        }

        i_result = 0;

        break;
    }

    if ( 0 != i_result )
    {
        stop_accept_thread( pst_iocp_accept );
        pst_iocp_accept = (struct iocp_accept*)0;
    }

    return (void *)pst_iocp_accept;
}